{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SARS-CoV-2 Knowledge Graph "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import time\n",
    "import re\n",
    "import math\n",
    "import random\n",
    "import pickle\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn import metrics \n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.nn.modules import Module\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "from torch_geometric.data import Data, DataLoader\n",
    "from torch_geometric.utils import train_test_split_edges\n",
    "from torch_geometric.utils import add_remaining_self_loops, add_self_loops\n",
    "from torch_geometric.utils import to_undirected\n",
    "from torch_geometric.nn import GCNConv, SAGEConv,GAE, VGAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path='data/'\n",
    "exp_id='v0'\n",
    "device_id=0 #'cpu' if CPU, device number if GPU\n",
    "embedding_size=128"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load preprocessed files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_pickle(prefix):\n",
    "    \"\"\"Unpickle ``data_path + prefix + exp_id + '.pkl'``, closing the file afterwards.\n",
    "\n",
    "    pickle can execute arbitrary code while loading, so only point this at\n",
    "    trusted, locally generated preprocessing outputs.\n",
    "    \"\"\"\n",
    "    with open(data_path+prefix+exp_id+'.pkl', 'rb') as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "le=_load_pickle('LabelEncoder_')\n",
    "edge_index=_load_pickle('edge_index_')\n",
    "node_feature_np=_load_pickle('node_feature_')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "node_feature=torch.tensor(node_feature_np, dtype=torch.float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "edge=torch.tensor(edge_index[['node1', 'node2']].values, dtype=torch.long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer code of each relation type; Encoder_VGAE.forward selects edges by these codes.\n",
    "edge_attr_dict={'gene-drug':0,'gene-gene':1,'bait-gene':2, 'gene-phenotype':3, 'drug-phenotype':4}\n",
    "# Values that are already codes pass through unchanged, so re-running this cell is a no-op\n",
    "# instead of a KeyError; unknown relation names still raise KeyError.\n",
    "_edge_type_codes=set(edge_attr_dict.values())\n",
    "edge_index['type']=edge_index['type'].apply(lambda x: x if x in _edge_type_codes else edge_attr_dict[x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    16972\n",
       "1    13423\n",
       "3     1401\n",
       "4      935\n",
       "2      330\n",
       "Name: type, dtype: int64"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "edge_index['type'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "edge_attr=torch.tensor(edge_index['type'].values,dtype=torch.long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = Data(x=node_feature,\n",
    "            edge_index=edge.t().contiguous(),\n",
    "            edge_attr=edge_attr\n",
    "           )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(400, 10624, 33061)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.num_features, data.num_nodes,data.num_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([33061])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "edge_attr.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(True, True)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.contains_isolated_nodes(), data.is_directed()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train/test edge split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1):\n",
    "    r\"\"\"Splits the edges of a :obj:`torch_geometric.data.Data` object\n",
    "    into positive and negative train/val/test edges.\n",
    "\n",
    "    Local variant of the torch_geometric utility of the same name (it shadows\n",
    "    the imported function): every edge of the directed graph is kept (no\n",
    "    upper-triangular filtering) and ``edge_attr`` is split together with the\n",
    "    edges. ``data.edge_index`` and ``data.edge_attr`` are left untouched.\n",
    "\n",
    "    Attributes added to :attr:`data`: `train_pos_edge_index`,\n",
    "    `train_pos_edge_attr`, `val_pos_edge_index`, `val_pos_edge_attr`,\n",
    "    `test_pos_edge_index`, `test_pos_edge_attr`, `train_neg_adj_mask`,\n",
    "    `val_neg_edge_index` and `test_neg_edge_index`. `test_post_edge_attr`\n",
    "    (the original, misspelled name) is kept as an alias of\n",
    "    `test_pos_edge_attr` for backward compatibility.\n",
    "\n",
    "    Args:\n",
    "        data (Data): The data object; must provide `num_nodes`, `edge_index`\n",
    "            and `edge_attr`.\n",
    "        val_ratio (float, optional): The ratio of positive validation\n",
    "            edges. (default: :obj:`0.05`)\n",
    "        test_ratio (float, optional): The ratio of positive test\n",
    "            edges. (default: :obj:`0.1`)\n",
    "\n",
    "    :rtype: :class:`torch_geometric.data.Data`\n",
    "    \"\"\"\n",
    "\n",
    "    assert 'batch' not in data  # No batch-mode.\n",
    "\n",
    "    num_nodes = data.num_nodes\n",
    "    row, col = data.edge_index\n",
    "    attr = data.edge_attr\n",
    "    num_edges = row.size(0)\n",
    "\n",
    "    n_v = int(math.floor(val_ratio * num_edges))\n",
    "    n_t = int(math.floor(test_ratio * num_edges))\n",
    "\n",
    "    # Positive edges: one permutation shuffles the edges and their types together.\n",
    "    perm = torch.randperm(num_edges)\n",
    "    row, col = row[perm], col[perm]\n",
    "    attr = attr[perm]\n",
    "\n",
    "    r, c = row[:n_v], col[:n_v]\n",
    "    data.val_pos_edge_index = torch.stack([r, c], dim=0)\n",
    "    data.val_pos_edge_attr = attr[:n_v]\n",
    "\n",
    "    r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t]\n",
    "    data.test_pos_edge_index = torch.stack([r, c], dim=0)\n",
    "    data.test_pos_edge_attr = attr[n_v:n_v + n_t]\n",
    "    # Alias for the attribute name used before the typo was fixed.\n",
    "    data.test_post_edge_attr = data.test_pos_edge_attr\n",
    "\n",
    "    r, c = row[n_v + n_t:], col[n_v + n_t:]\n",
    "    data.train_pos_edge_index = torch.stack([r, c], dim=0)\n",
    "    data.train_pos_edge_attr = attr[n_v + n_t:]\n",
    "\n",
    "    # Negative edges: candidates are the upper-triangular node pairs.\n",
    "    neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8)\n",
    "    neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool)\n",
    "    # Edges are directed and not restricted to row < col, so a linked pair must be\n",
    "    # cleared in both orientations; otherwise an edge stored as (j, i) with j > i\n",
    "    # leaves (i, j) available as a false negative.\n",
    "    neg_adj_mask[row, col] = 0\n",
    "    neg_adj_mask[col, row] = 0\n",
    "\n",
    "    neg_row, neg_col = neg_adj_mask.nonzero().t()\n",
    "    perm = random.sample(range(neg_row.size(0)),\n",
    "                         min(n_v + n_t, neg_row.size(0)))\n",
    "    perm = torch.tensor(perm, dtype=torch.long)\n",
    "    neg_row, neg_col = neg_row[perm], neg_col[perm]\n",
    "\n",
    "    # Held-out negatives are removed from the mask of training negatives.\n",
    "    neg_adj_mask[neg_row, neg_col] = 0\n",
    "    data.train_neg_adj_mask = neg_adj_mask\n",
    "\n",
    "    row, col = neg_row[:n_v], neg_col[:n_v]\n",
    "    data.val_neg_edge_index = torch.stack([row, col], dim=0)\n",
    "\n",
    "    row, col = neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t]\n",
    "    data.test_neg_edge_index = torch.stack([row, col], dim=0)\n",
    "\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "device=torch.device(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# val_ratio=0: no validation edges; test_ratio=0.1 holds out 10% of the edges as test positives.\n",
    "data_split=train_test_split_edges(data, val_ratio=0, test_ratio=0.1)\n",
    "# Move the node features and the training edges (with their types) to the compute device.\n",
    "x=data_split.x.to(device)\n",
    "train_pos_edge_index=data_split.train_pos_edge_index.to(device)\n",
    "train_pos_edge_attr=data_split.train_pos_edge_attr.to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_pos_edge_index, train_pos_edge_attr=add_remaining_self_loops(train_pos_edge_index,train_pos_edge_attr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1    22599\n",
       "0    15295\n",
       "3     1259\n",
       "4      856\n",
       "2      290\n",
       "dtype: int64"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.Series(train_pos_edge_attr.cpu().numpy()).value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "x,train_pos_edge_index,train_pos_edge_attr = Variable(x),Variable(train_pos_edge_index),Variable(train_pos_edge_attr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learning models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define VGAE model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder_VGAE(nn.Module):\n",
    "    \"\"\"Relation-aware GraphSAGE encoder returning the VGAE mean and log-variance.\n",
    "\n",
    "    One SAGEConv per edge-type code (0..4) is applied to the sub-graph made of\n",
    "    the edges carrying that code; the five hidden representations are\n",
    "    concatenated, batch-normalised and passed through two SAGEConv heads that\n",
    "    run over all edges.\n",
    "\n",
    "    Args:\n",
    "        in_channels: size of the input node features.\n",
    "        out_channels: size of the latent embedding; each per-type branch is\n",
    "            ``2*out_channels`` wide.\n",
    "        isClassificationTask: stored on the instance; not read by ``forward``.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_channels, out_channels, isClassificationTask=False):\n",
    "        super(Encoder_VGAE, self).__init__()\n",
    "        self.isClassificationTask=isClassificationTask\n",
    "        branch_width=2*out_channels\n",
    "        # Creation order is kept so parameter initialisation consumes the RNG identically.\n",
    "        self.conv_gene_drug=SAGEConv(in_channels, branch_width)\n",
    "        self.conv_gene_gene=SAGEConv(in_channels, branch_width)\n",
    "        self.conv_bait_gene=SAGEConv(in_channels, branch_width)\n",
    "        self.conv_gene_phenotype=SAGEConv(in_channels, branch_width)\n",
    "        self.conv_drug_phenotype=SAGEConv(in_channels, branch_width)\n",
    "\n",
    "        self.bn=nn.BatchNorm1d(5*branch_width)\n",
    "        # variational heads\n",
    "        self.conv_mu=SAGEConv(5*branch_width, out_channels)\n",
    "        self.conv_logvar=SAGEConv(5*branch_width, out_channels)\n",
    "\n",
    "    def forward(self,x,edge_index,edge_attr):\n",
    "        \"\"\"Return ``(mu, logvar)`` for the nodes in ``x``.\n",
    "\n",
    "        ``edge_attr`` holds one integer type code per column of ``edge_index``.\n",
    "        \"\"\"\n",
    "        x=F.dropout(x, training=self.training)\n",
    "\n",
    "        # (edge-type code, convolution, dropout probability), in concatenation order.\n",
    "        branches=(\n",
    "            (0, self.conv_gene_drug, 0.5),\n",
    "            (1, self.conv_gene_gene, 0.5),\n",
    "            (2, self.conv_bait_gene, 0.1),\n",
    "            (3, self.conv_gene_phenotype, 0.5),\n",
    "            (4, self.conv_drug_phenotype, 0.5),\n",
    "        )\n",
    "        hidden=[]\n",
    "        for type_code, conv, drop_p in branches:\n",
    "            typed_columns=(edge_attr==type_code).nonzero().reshape(-1)\n",
    "            typed_edges=edge_index[:, typed_columns]\n",
    "            hidden.append(F.dropout(F.relu(conv(x,typed_edges)), p=drop_p, training=self.training))\n",
    "\n",
    "        x=self.bn(torch.cat(hidden,dim=1))\n",
    "\n",
    "        return self.conv_mu(x,edge_index), self.conv_logvar(x,edge_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "model=VGAE(Encoder_VGAE(node_feature.shape[1], embedding_size)).to(device)\n",
    "optimizer=torch.optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"Run one full-graph optimisation step of the VGAE and print the loss.\n",
    "\n",
    "    The loss is the reconstruction loss on the training edges plus the KL\n",
    "    term weighted by ``1 / data.num_nodes``. Uses the notebook-level ``model``,\n",
    "    ``optimizer``, ``x``, ``train_pos_edge_index`` and ``train_pos_edge_attr``.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    latent = model.encode(x, train_pos_edge_index, train_pos_edge_attr)\n",
    "    reconstruction = model.recon_loss(latent, train_pos_edge_index)\n",
    "    total = reconstruction + (1 / data.num_nodes) * model.kl_loss()\n",
    "    total.backward()\n",
    "    optimizer.step()\n",
    "    print(total.item())\n",
    "    \n",
    "def test(pos_edge_index, neg_edge_index):\n",
    "    \"\"\"Score the given positive/negative edges with embeddings of the training graph.\n",
    "\n",
    "    Returns whatever ``model.test`` returns for the encoded nodes (unpacked by\n",
    "    the caller as ``auc, ap``).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    # Encoding is inference only here, so no autograd graph is recorded.\n",
    "    with torch.no_grad():\n",
    "        latent = model.encode(x, train_pos_edge_index, train_pos_edge_attr)\n",
    "    return model.test(latent, pos_edge_index, neg_edge_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#DRKG's accuracy for comparison\n",
    "model.test(x,data_split.test_pos_edge_index, data_split.test_neg_edge_index )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# range(1, 10) runs 9 epochs; each epoch is one train() step followed by an\n",
    "# evaluation on the held-out test edges (positives and sampled negatives).\n",
    "for epoch in range(1, 10):\n",
    "    train()\n",
    "    auc, ap = test(data_split.test_pos_edge_index, data_split.test_neg_edge_index)\n",
    "    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Node embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "# Inference only: no_grad avoids recording an autograd graph for the full-graph pass.\n",
    "# NOTE(review): training used edges with self-loops added (add_remaining_self_loops),\n",
    "# while this call encodes data.edge_index without them -- confirm this is intended.\n",
    "with torch.no_grad():\n",
    "    z=model.encode(x, data.edge_index.to(device), data.edge_attr.to(device))\n",
    "z_np = z.squeeze().detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the new embedding "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(z_np, open(data_path+'node_embedding_'+exp_id+'.pkl', 'wb'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the torch model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), data_path+'VAE_encoders_'+exp_id+'.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "VGAE(\n",
       "  (encoder): Encoder_VGAE(\n",
       "    (conv_gene_drug): SAGEConv(400, 256)\n",
       "    (conv_gene_gene): SAGEConv(400, 256)\n",
       "    (conv_bait_gene): SAGEConv(400, 256)\n",
       "    (conv_gene_phenotype): SAGEConv(400, 256)\n",
       "    (conv_drug_phenotype): SAGEConv(400, 256)\n",
       "    (bn): BatchNorm1d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (conv_mu): SAGEConv(1280, 128)\n",
       "    (conv_logvar): SAGEConv(1280, 128)\n",
       "  )\n",
       "  (decoder): InnerProductDecoder()\n",
       ")"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# map_location remaps the stored tensors onto the configured device, so a checkpoint\n",
    "# saved on a GPU also loads on a CPU-only (or differently numbered) machine.\n",
    "model.load_state_dict(torch.load(data_path+'VAE_encoders_'+exp_id+'.pkl', map_location=device))\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ranking model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.naive_bayes import GaussianNB\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier\n",
    "from sklearn.model_selection import train_test_split, cross_val_score\n",
    "from sklearn import metrics\n",
    "from sklearn.metrics import roc_auc_score, roc_curve, average_precision_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "metadata": {},
   "outputs": [],
   "source": [
    "topk=300\n",
    "types=np.array([item.split('_')[0] for item in le.classes_ ])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load drugs under clinical trial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "#label\n",
    "trials=pd.read_excel(data_path+'literature-mining/All_trails_5_24.xlsx',header=1,index_col=0)\n",
    "\n",
    "def _split_intervention(text):\n",
    "    \"\"\"Split one free-text intervention string into its individual drug mentions.\"\"\"\n",
    "    # Connectors are rewritten to '/' in this exact order before splitting.\n",
    "    for connector in (' vs. ', ' vs ', ' or ', ' with and without ', ' /wo ', ' /w ', ' and ', ' - ', ' (', ') '):\n",
    "        text=text.replace(connector, '/')\n",
    "    return re.split(r'[+|/|,]',text)\n",
    "\n",
    "# Only rows whose study category mentions 'drug' contribute drug names.\n",
    "is_drug_study=trials['study_category'].apply(lambda x: 'drug' in x.lower())\n",
    "trials_drug={name.strip().upper() for parts in trials.loc[is_drug_study,'intervention'].apply(_split_intervention).values for name in parts}\n",
    "# 1 when the drug node's name (text after the first '_') appears in a trial, else 0.\n",
    "drug_labels=[int(drug.split('_')[1] in trials_drug) for drug in le.classes_[types=='drug'] ]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BPR loss NN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed=70\n",
    "indices = np.arange(len(drug_labels))\n",
    "X_train, X_test, y_train, y_test,indices_train,indices_test=train_test_split(z_np[types=='drug'],drug_labels,indices, test_size=0.5,random_state=seed,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Variable wrapping for torch.tensor\n",
    "_X_train, _y_train=Variable(torch.tensor(X_train,dtype=torch.float).to(device)), Variable(torch.tensor(y_train,dtype=torch.float).to(device))\n",
    "_X_test, _y_test=Variable(torch.tensor(X_test,dtype=torch.float).to(device)), Variable(torch.tensor(y_test,dtype=torch.float).to(device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Classifier(nn.Module):\n",
    "    \"\"\"Residual scorer: one hidden layer with a skip connection, then a single output.\n",
    "\n",
    "    ``forward`` returns one unnormalised score per input row.\n",
    "    \"\"\"\n",
    "    def __init__(self,embedding_dim=embedding_size):\n",
    "        super(Classifier, self).__init__()\n",
    "        self.fc1=nn.Linear(embedding_dim,embedding_dim)\n",
    "        self.fc2=nn.Linear(embedding_dim,1)\n",
    "        self.bn=nn.BatchNorm1d(embedding_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        skip = x  # the un-dropped input feeds the residual connection\n",
    "        hidden = F.dropout(x, training=self.training)\n",
    "        hidden = F.relu(self.fc1(hidden))\n",
    "        hidden = F.dropout(hidden, training=self.training)\n",
    "        hidden = self.bn(hidden)\n",
    "        hidden += skip\n",
    "        return self.fc2(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import BatchSampler, WeightedRandomSampler\n",
    "class BPRLoss(nn.Module):\n",
    "    \"\"\"Bayesian Personalized Ranking loss with score-weighted negative sampling.\n",
    "\n",
    "    For every positive example, ``num_neg_samples`` negatives are drawn with\n",
    "    replacement, with probability proportional to their score minus the lowest\n",
    "    negative score (harder negatives are drawn more often). The loss is the\n",
    "    mean of ``-log(sigmoid(positive - negative))`` over all drawn pairs.\n",
    "\n",
    "    Args:\n",
    "        num_neg_samples: number of negatives drawn per positive example.\n",
    "    \"\"\"\n",
    "    def __init__(self, num_neg_samples):\n",
    "        super(BPRLoss, self).__init__()\n",
    "        self.num_neg_samples=num_neg_samples\n",
    "\n",
    "    def forward(self, output, label):\n",
    "        \"\"\"Return the scalar loss for 1-D ``output`` scores and matching ``label`` (1 = positive).\"\"\"\n",
    "        positive_output=output[label==1]\n",
    "        negative_output=output[label!=1]\n",
    "        num_pos=positive_output.numel()\n",
    "        if num_pos==0 or negative_output.numel()==0:\n",
    "            # No (positive, negative) pair exists: return a zero that stays attached\n",
    "            # to the graph so a following backward() still works.\n",
    "            return output.sum()*0\n",
    "\n",
    "        # negative sample proportional to the high values; the weights only steer the\n",
    "        # sampling, so they are detached from autograd.\n",
    "        weights=(negative_output-negative_output.min()).detach().double()\n",
    "        if weights.sum()<=0:\n",
    "            # All negatives tie: the shifted weights are all zero, which multinomial\n",
    "            # rejects, so fall back to uniform sampling.\n",
    "            weights=torch.ones_like(weights)\n",
    "        sample_idx=torch.multinomial(weights, self.num_neg_samples*num_pos, replacement=True)\n",
    "        # (num_pos, num_neg_samples): row i holds the negatives paired with positive i.\n",
    "        sample_idx=sample_idx.view(self.num_neg_samples, num_pos).t()\n",
    "        negative_sample_output=negative_output[sample_idx]\n",
    "        return -(positive_output.view(-1,1)-negative_sample_output).sigmoid().log().mean()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf=Classifier(embedding_size).to(device)\n",
    "optimizer=torch.optim.Adam(clf.parameters())\n",
    "criterion=BPRLoss(num_neg_samples=15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training loss 0.41865915060043335\n",
      "test loss 0.6383747458457947\n",
      "training loss 0.3675549030303955\n",
      "test loss 0.6116868853569031\n",
      "training loss 0.4384680986404419\n",
      "test loss 0.6033534407615662\n",
      "training loss 0.4908576011657715\n",
      "test loss 0.6275379061698914\n",
      "training loss 0.36719879508018494\n",
      "test loss 0.6025533080101013\n",
      "training loss 0.3904884457588196\n",
      "test loss 0.6043612957000732\n",
      "training loss 0.40178656578063965\n",
      "test loss 0.5972631573677063\n",
      "training loss 0.39521604776382446\n",
      "test loss 0.5915830135345459\n",
      "training loss 0.38820844888687134\n",
      "test loss 0.6791005730628967\n",
      "training loss 0.40899592638015747\n",
      "test loss 0.6104799509048462\n",
      "training loss 0.37142685055732727\n",
      "test loss 0.6891355514526367\n",
      "training loss 0.37280625104904175\n",
      "test loss 0.6197623610496521\n",
      "training loss 0.3570763170719147\n",
      "test loss 0.6380096673965454\n",
      "training loss 0.3555067777633667\n",
      "test loss 0.656690776348114\n",
      "training loss 0.3914637863636017\n",
      "test loss 0.6770980954170227\n",
      "training loss 0.4530521035194397\n",
      "test loss 0.5766391754150391\n",
      "training loss 0.3638915419578552\n",
      "test loss 0.56634920835495\n",
      "training loss 0.37793928384780884\n",
      "test loss 0.7055246829986572\n",
      "training loss 0.37624314427375793\n",
      "test loss 0.5402485728263855\n",
      "training loss 0.38515210151672363\n",
      "test loss 0.5837911367416382\n",
      "training loss 0.38353431224823\n",
      "test loss 0.5742397904396057\n",
      "training loss 0.35172799229621887\n",
      "test loss 0.7139874696731567\n",
      "training loss 0.3436571955680847\n",
      "test loss 0.6214970946311951\n",
      "training loss 0.39446115493774414\n",
      "test loss 0.6644965410232544\n",
      "training loss 0.38150277733802795\n",
      "test loss 0.6938989758491516\n",
      "training loss 0.37255266308784485\n",
      "test loss 0.6802477836608887\n",
      "training loss 0.3785788118839264\n",
      "test loss 0.6574035286903381\n",
      "training loss 0.41502097249031067\n",
      "test loss 0.6154575943946838\n",
      "training loss 0.4060690402984619\n",
      "test loss 0.6205570101737976\n",
      "training loss 0.34869033098220825\n",
      "test loss 0.6109451651573181\n"
     ]
    }
   ],
   "source": [
    "best_auprc=0\n",
    "for epoch in range(30):\n",
    "    clf.train()\n",
    "    optimizer.zero_grad()\n",
    "    out = clf(_X_train)\n",
    "    loss=criterion(out.squeeze(), _y_train)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    print('training loss',loss.item())\n",
    "\n",
    "    clf.eval()\n",
    "    # Evaluation needs no gradients; in eval mode one forward pass serves both the\n",
    "    # loss and the probabilities.\n",
    "    with torch.no_grad():\n",
    "        test_out=clf(_X_test)\n",
    "        print('test loss', criterion(test_out.squeeze(), _y_test).item())\n",
    "    prob=torch.sigmoid(test_out).cpu().numpy().squeeze()\n",
    "    auprc=metrics.average_precision_score(y_test,prob)\n",
    "    # NOTE(review): the checkpoint is selected on the test split; a separate\n",
    "    # validation split would avoid optimistic test metrics.\n",
    "    if auprc>best_auprc:\n",
    "        # Track the best score so only improving epochs overwrite the checkpoint.\n",
    "        best_auprc=auprc\n",
    "        torch.save(clf, data_path+'nn_clf-temp.pt')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf.load_state_dict(torch.load(data_path+'nn_clf.pt').state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Compute AUC\n",
    "clf.eval()\n",
    "\n",
    "prob=torch.sigmoid(clf(_X_test)).cpu().detach().numpy().squeeze()\n",
    "print(\"AUROC\", metrics.roc_auc_score(y_test,prob))\n",
    "print(\"AUPRC\", metrics.average_precision_score(y_test,prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_items_idx=np.argsort(-clf(torch.tensor(z_np[types=='drug'],dtype=torch.float).to(device)).squeeze().detach().cpu().numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baseline models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Logit AUROC 0.6800655093350803\n",
      "Logit AUPRC 0.06040678123866987\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
      "  FutureWarning)\n"
     ]
    }
   ],
   "source": [
    "clf=LogisticRegression().fit(X_train,y_train)\n",
    "print(\"Logit AUROC\", roc_auc_score(y_test,clf.predict_proba(X_test)[:,1]))\n",
    "print(\"Logit AUPRC\", average_precision_score(y_test,clf.predict_proba(X_test)[:,1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "XGBoost AUROC 0.7018932197838192\n",
      "XGBoost AUPRC 0.08359722729002399\n"
     ]
    }
   ],
   "source": [
    "# scikit-learn gradient boosting (not the xgboost package); seeded so the baseline is reproducible.\n",
    "clf=GradientBoostingClassifier(random_state=seed).fit(X_train,y_train)\n",
    "gb_prob=clf.predict_proba(X_test)[:,1]\n",
    "print(\"GradientBoosting AUROC\", roc_auc_score(y_test,gb_prob))\n",
    "print(\"GradientBoosting AUPRC\", average_precision_score(y_test,gb_prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rf AUROC 0.6161415001637733\n",
      "rf AUPRC 0.09399119591651842\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.7/site-packages/sklearn/ensemble/forest.py:245: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.\n",
      "  \"10 in version 0.20 to 100 in 0.22.\", FutureWarning)\n"
     ]
    }
   ],
   "source": [
    "# Seeded so the random-forest baseline gives the same numbers on every run.\n",
    "clf=RandomForestClassifier(random_state=seed).fit(X_train,y_train)\n",
    "rf_prob=clf.predict_proba(X_test)[:,1]\n",
    "print(\"rf AUROC\", roc_auc_score(y_test,rf_prob))\n",
    "print(\"rf AUPRC\", average_precision_score(y_test,rf_prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "svm AUROC 0.6915296429741238\n",
      "svm AUPRC 0.11493301290344271\n"
     ]
    }
   ],
   "source": [
    "clf=make_pipeline(StandardScaler(), SVC(gamma='auto',probability=True)).fit(X_train,y_train)\n",
    "print(\"svm AUROC\", roc_auc_score(y_test,clf.predict_proba(X_test)[:,1]))\n",
    "print(\"svm AUPRC\", average_precision_score(y_test,clf.predict_proba(X_test)[:,1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#top_items_idx=np.argsort(-clf.predict_proba(z_np[types=='drug'])[:,1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the high-ranked drugs into csv file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drug names (text after the first '_') in the same order as the drug rows of z_np.\n",
    "drug_names=np.array([drug.split('_')[1] for drug in le.classes_[types=='drug']])\n",
    "# top_items_idx orders the drug subset by descending score; keep exactly topk rows\n",
    "# (the previous [:topk+1] slice exported topk+1 drugs). Ranks stay 0-based.\n",
    "topk_drugs=pd.DataFrame(list(enumerate(drug_names[top_items_idx][:topk])), columns=['rank', 'drug'])\n",
    "topk_drugs['under_trials']=topk_drugs['drug'].isin(trials_drug).astype(int)\n",
    "topk_drugs['is_used_in_training']=topk_drugs['drug'].isin(drug_names[indices_train]).astype(int)\n",
    "# File name follows topk ('top300_drugs.csv' for topk=300).\n",
    "topk_drugs.to_csv('top{}_drugs.csv'.format(topk))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
